229a7c
@@ -2444,99 +2444,54 @@ public static boolean isEmptyPath(Configuration job, Path dirPath) throws IOExce
   }
 
   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
-    List<TezTask> tezTasks = new ArrayList<TezTask>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>();
-      while (!tasks.isEmpty()) {
-        tasks = getTezTasks(tasks, tezTasks, visited);
-      }
-    }
-    return tezTasks;
-  }
-
-  private static List<Task<? extends Serializable>> getTezTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<TezTask> tezTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof TezTask && !tezTasks.contains(task)) {
-        tezTasks.add((TezTask) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
-      }
-      visited.add(task);
-    }
-    return childTasks;
+    return getTasks(tasks, TezTask.class);
   }
 
   public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
-    List<SparkTask> sparkTasks = new ArrayList<SparkTask>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>();
-      while (!tasks.isEmpty()) {
-        tasks = getSparkTasks(tasks, sparkTasks, visited);
-      }
-    }
-    return sparkTasks;
+    return getTasks(tasks, SparkTask.class);
   }
 
-  private static List<Task<? extends Serializable>> getSparkTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<SparkTask> sparkTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof SparkTask && !sparkTasks.contains(task)) {
-        sparkTasks.add((SparkTask) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
-      }
-      visited.add(task);
-    }
-    return childTasks;
+  public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
+    return getTasks(tasks, ExecDriver.class);
   }
 
-  public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
-    List<ExecDriver> mrTasks = new ArrayList<ExecDriver>();
+  /**
+   * Collects every task of type {@code requiredType} that is reachable from {@code tasks}.
+   * <p>
+   * The task graph is walked level by level (breadth-first) through
+   * {@link Task#getDependentTasks()}. A task reachable along several paths is processed
+   * only once, so each match appears in the result once, in discovery order. The caller's
+   * list is not modified.
+   *
+   * @param <T> the task type to collect
+   * @param tasks the root tasks to start from; {@code null} yields an empty list
+   * @param requiredType the class to match; instances of subclasses match as well
+   * @return a new mutable list of the matching tasks, never {@code null}
+   */
+  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
+    List<T> typeSpecificTasks = new ArrayList<>();
     if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>();
+      Set<Task<? extends Serializable>> visited = new HashSet<>();
       while (!tasks.isEmpty()) {
-        tasks = getMRTasks(tasks, mrTasks, visited);
-      }
-    }
-    return mrTasks;
-  }
-
-  private static List<Task<? extends Serializable>> getMRTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<ExecDriver> mrTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof ExecDriver && !mrTasks.contains(task)) {
-        mrTasks.add((ExecDriver) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
+        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
+        for (Task<? extends Serializable> task : tasks) {
+          // Set.add returns false for a task already reached via another parent; skipping
+          // it also guarantees each match is collected once without a List.contains scan.
+          if (!visited.add(task)) {
+            continue;
+          }
+          if (requiredType.isInstance(task)) {
+            typeSpecificTasks.add(requiredType.cast(task));
+          }
+          if (task.getDependentTasks() != null) {
+            childTasks.addAll(task.getDependentTasks());
+          }
+        }
+        // Iterate (rather than recurse) into the next level of dependent tasks.
+        tasks = childTasks;
       }
-      visited.add(task);
     }
-    return childTasks;
+    return typeSpecificTasks;
   }
 
   /**
